library(magrittr)
library(tidyverse)
library(Seurat)
library(readxl)
library(cowplot)
library(colorblindr)
library(viridis)
library(magick, lib.loc = "/home/uhlitzf/miniconda3/lib/R/library")
library(ggpubr)
## load global vars: 
source("_src/global_vars.R")

# meta_tbl
# clrs
# markers_v7
# markers_v7_super
# cell_type_super_lookup

# Relabel the cell type palette: dots in the names become spaces and
# "Ovarian" is abbreviated to "Ov".
names(clrs$cell_type) <- names(clrs$cell_type) %>%
  str_replace_all("\\.", " ") %>%
  str_replace_all("Ovarian", "Ov")

## load data --------------------------------------

## load cohort embeddings
## (labels are harmonised on load: "Monocyte" is folded into "Myeloid.cell"
## and "Ovarian.cancer.cell" is abbreviated to "Ov.cancer.cell")
seu_tbl <- read_tsv("/work/shah/uhlitzf/data/SPECTRUM/freeze/v7/outs_pre/cells.tsv") %>%
  mutate(cell_type = case_when(
    cell_type == "Monocyte" ~ "Myeloid.cell",
    cell_type == "Ovarian.cancer.cell" ~ "Ov.cancer.cell",
    TRUE ~ cell_type
  ))

## load consensus data
## keep only the cell id and its consensusOV label; the label becomes an
## ordered factor whose levels follow the names of `clrs$consensusOV`
consOV_tbl <- read_tsv("/work/shah/uhlitzf/data/SPECTRUM/freeze/v7/consensusOV/SPECTRUM_freeze_v7_consensusOV.tsv") %>%
  transmute(
    cell_id,
    consensusOV = factor(consensusOV, levels = names(clrs$consensusOV), ordered = TRUE)
  )

## join data
## cells + sample meta data, restricted to pre-Rx samples with a known
## supersite and a cell type other than "Other"; consensusOV labels are
## attached per cell at the end
seu_tbl_full <- seu_tbl %>%
  left_join(meta_tbl, by = "sample") %>%
  filter(therapy == "pre-Rx", cell_type != "Other", tumor_supersite != "Unknown") %>%
  rename(UMAP_1 = umap50_1, UMAP_2 = umap50_2) %>%
  mutate(
    # look up the super type first, while `cell_type` still holds the
    # dotted labels used as lookup keys
    cell_type_super = cell_type_super_lookup[cell_type],
    # then switch to the spaced labels, ordered like the palette
    cell_type = ordered(
      str_replace_all(cell_type, "\\.", " "),
      levels = names(clrs$cell_type)
    ),
    # cells with sort label "U" are assigned to CD45+ / CD45- according to
    # their super type; all other sort labels are kept unchanged
    sort_short_x = ifelse(
      sort_short == "U" & cell_type_super == "Immune",
      "CD45+",
      ifelse(
        sort_short == "U" & cell_type_super == "Stromal",
        "CD45-",
        sort_short
      )
    )
  ) %>%
  left_join(consOV_tbl, by = "cell_id")

# seu_tbl_full <- seu_tbl_full %>%
#   sample_n(10000)

1 scRNA compositions

source("_src/comp_plot.R")
# rank_by()

# Append three constant label columns ("Site", "Rx", "Signature") to a table.
# Takes a data frame as its only argument and returns it with the columns
# `label_supersite`, `label_therapy` and `label_mutsig` added.
add_helper_columns <- function(tbl) {
  mutate(
    tbl,
    label_supersite = "Site",
    label_therapy = "Rx",
    label_mutsig = "Signature"
  )
}

# Cell type composition per sorted sample of the scRNA cohort.
#   n      : cells of a cell type within one
#            subsite / patient / therapy / sort combination
#   nrel   : percentage of that combination made up by the cell type
#   ntotal : total number of cells per `sample_id`
comp_tbl_sample <- seu_tbl_full %>%
  filter(therapy == "pre-Rx", cell_type != "Other") %>%
  group_by(tumor_subsite, tumor_supersite, tumor_megasite, patient_id_short,
           therapy, sort_short_x, consensus_signature, cell_type) %>%
  tally() %>%
  group_by(tumor_subsite, tumor_supersite, tumor_megasite, patient_id_short,
           therapy, sort_short_x, consensus_signature) %>%
  mutate(nrel = n/sum(n)*100) %>%
  # Drop the grouping before re-leveling `tumor_supersite`: it is one of the
  # grouping columns, and modifying a grouping column inside a grouped
  # mutate() is rejected by older dplyr releases (and evaluated group by
  # group otherwise).
  ungroup() %>%
  add_helper_columns %>%
  mutate(tumor_supersite = ordered(tumor_supersite, levels = rev(names(clrs$tumor_supersite)))) %>%
  mutate(sample_id = paste(tumor_subsite, patient_id_short, therapy, sort_short_x)) %>%
  group_by(sample_id) %>%
  mutate(ntotal = sum(n)) %>%
  filter(ntotal > 0) %>%
  ungroup()


# consensusOV subtype composition per sorted sample of the scRNA cohort.
#   n      : cells with a given consensusOV label within one
#            subsite / patient / therapy / sort combination
#   nrel   : percentage of that combination carrying the label
#   log10n : log10 of `n`
comp_tbl_consOV <- seu_tbl_full %>%
  filter(therapy == "pre-Rx", cell_type != "Other") %>%
  group_by(tumor_subsite, tumor_supersite, tumor_megasite, patient_id_short,
           therapy, sort_short_x, consensus_signature, consensusOV) %>%
  tally() %>%
  group_by(tumor_subsite, tumor_supersite, tumor_megasite, patient_id_short,
           therapy, sort_short_x, consensus_signature) %>%
  mutate(nrel = n/sum(n)*100,
         log10n = log10(n)) %>%
  # Drop the grouping before re-leveling `tumor_supersite`: it is one of the
  # grouping columns, and modifying a grouping column inside a grouped
  # mutate() is rejected by older dplyr releases (and evaluated group by
  # group otherwise).
  ungroup() %>%
  add_helper_columns %>%
  mutate(tumor_supersite = ordered(tumor_supersite, levels = rev(names(clrs$tumor_supersite)))) %>%
  mutate(sample_id = paste(tumor_subsite, patient_id_short, therapy, sort_short_x))


# 
# g9 <- default_comp_grid(filter(comp_tbl_consOV, sort_short_x == "CD45-"),
#                         consensusOV, "Immunoreactive")
# g10 <- default_comp_grid(filter(comp_tbl_consOV, sort_short_x == "CD45-"),
#                          consensusOV, "Mesenchymal")
# g11 <- default_comp_grid(filter(comp_tbl_consOV, sort_short_x == "CD45-"),
#                          consensusOV, "Differentiated")
# g12 <- default_comp_grid(filter(comp_tbl_consOV, sort_short_x == "CD45-"),
#                          consensusOV, "Proliferative")
# 
# g13 <- default_comp_grid(filter(comp_tbl_consOV, sort_short_x == "CD45+"),
#                          consensusOV, "Immunoreactive")
# g14 <- default_comp_grid(filter(comp_tbl_consOV, sort_short_x == "CD45+"),
#                          consensusOV, "Mesenchymal")
# g15 <- default_comp_grid(filter(comp_tbl_consOV, sort_short_x == "CD45+"),
#                          consensusOV, "Differentiated")
# g16 <- default_comp_grid(filter(comp_tbl_consOV, sort_short_x == "CD45+"),
#                          consensusOV, "Proliferative")

# pdf("_fig/002_cohort/002_comp_full.pdf", width = 3.5, height = 12)
# g1;g2;g3;g4;g5;g6;g7;g8;g9;g10;g11;g12;g13;g14;g15;g16
# dev.off()

1.1 cell type

1.1.1 Site

# Display labels of the immune / stromal cell types: names of the lookup
# entries with dots replaced by spaces (stromal labels additionally have
# "Ovarian" shortened to "Ov").
# NOTE(review): if the lookup holds both an "Ovarian..." and an "Ov..." key,
# `cell_types_stromal` contains the same label twice -- confirm, and consider
# unique().
cell_types_immune <- str_replace_all(names(cell_type_super_lookup[cell_type_super_lookup=="Immune"]), "\\.", " ")
cell_types_stromal <- str_replace_all(names(cell_type_super_lookup[cell_type_super_lookup=="Stromal"]), "\\.", " ") %>% 
  str_replace_all("Ovarian", "Ov")

# One stacked figure per immune cell type, built from the CD45+ samples, with
# the mutational signature box switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(cell_types_immune)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', cell_types_immune[i],' \n')

  plist <- default_comp_grid_list(
    filter(comp_tbl_sample, sort_short_x == "CD45+"),
    cell_type, cell_types_immune[i], cell_type, mutsig_box = FALSE)

  # stack the panels vertically; the four relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.13, 0.13, 0.13, 0.61))

  print(p)

  cat(' \n \n')

}

1.1.1.1 B cell

1.1.1.2 Plasma cell

1.1.1.3 T cell

1.1.1.4 Myeloid cell

1.1.1.5 Mast cell

1.1.1.6 Dendritic cell

# One stacked figure per stromal cell type, built from the CD45- samples,
# with the mutational signature box switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(cell_types_stromal)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', cell_types_stromal[i],' \n')

  plist <- default_comp_grid_list(
    filter(comp_tbl_sample, sort_short_x == "CD45-"),
    cell_type, cell_types_stromal[i], cell_type, mutsig_box = FALSE)

  # stack the panels vertically; the four relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.13, 0.13, 0.13, 0.61))

  print(p)

  cat(' \n \n')

}

1.1.1.7 Endothelial cell

1.1.1.8 Fibroblast

1.1.1.9 Ov cancer cell

1.1.1.10 Ov cancer cell

1.1.2 Signature

# Display labels of the immune / stromal cell types: names of the lookup
# entries with dots replaced by spaces (stromal labels additionally have
# "Ovarian" shortened to "Ov").
# NOTE(review): if the lookup holds both an "Ovarian..." and an "Ov..." key,
# `cell_types_stromal` contains the same label twice -- confirm, and consider
# unique().
cell_types_immune <- str_replace_all(names(cell_type_super_lookup[cell_type_super_lookup=="Immune"]), "\\.", " ")
cell_types_stromal <- str_replace_all(names(cell_type_super_lookup[cell_type_super_lookup=="Stromal"]), "\\.", " ") %>% 
  str_replace_all("Ovarian", "Ov")

# One stacked figure per immune cell type, built from the CD45+ samples, with
# the site box switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(cell_types_immune)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', cell_types_immune[i],' \n')

  plist <- default_comp_grid_list(
    filter(comp_tbl_sample, sort_short_x == "CD45+"),
    cell_type, cell_types_immune[i], cell_type, site_box = FALSE)

  # stack the panels vertically; the four relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.13, 0.13, 0.13, 0.61))

  print(p)

  cat(' \n \n')

}

1.1.2.1 B cell

1.1.2.2 Plasma cell

1.1.2.3 T cell

1.1.2.4 Myeloid cell

1.1.2.5 Mast cell

1.1.2.6 Dendritic cell

# One stacked figure per stromal cell type, built from the CD45- samples,
# with the site box switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(cell_types_stromal)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', cell_types_stromal[i],' \n')

  plist <- default_comp_grid_list(
    filter(comp_tbl_sample, sort_short_x == "CD45-"),
    cell_type, cell_types_stromal[i], cell_type, site_box = FALSE)

  # stack the panels vertically; the four relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.13, 0.13, 0.13, 0.61))

  print(p)

  cat(' \n \n')

}

1.1.2.7 Endothelial cell

1.1.2.8 Fibroblast

1.1.2.9 Ov cancer cell

1.1.2.10 Ov cancer cell

1.2 TCGA

1.2.1 Site

# consensusOV labels present in the consensus table, as plain strings (in
# order of first appearance in `consOV_tbl`).
tcga_subtypes <- as.character(unique(consOV_tbl$consensusOV))

# One stacked figure per consensusOV subtype, built from the CD45+ samples,
# with the mutational signature box switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(tcga_subtypes)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', tcga_subtypes[i],' \n')

  plist <- default_comp_grid_list(
    filter(comp_tbl_consOV, sort_short_x == "CD45+"),
    consensusOV, tcga_subtypes[i], consensusOV, mutsig_box = FALSE)

  # stack the panels vertically; the four relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.13, 0.13, 0.13, 0.61))

  print(p)

  cat(' \n \n')

}

1.2.1.1 Differentiated

1.2.1.2 Immunoreactive

1.2.1.3 Mesenchymal

1.2.1.4 Proliferative

# One stacked figure per consensusOV subtype, built from the CD45- samples,
# with the mutational signature box switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(tcga_subtypes)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', tcga_subtypes[i],' \n')

  plist <- default_comp_grid_list(
    filter(comp_tbl_consOV, sort_short_x == "CD45-"),
    consensusOV, tcga_subtypes[i], consensusOV, mutsig_box = FALSE)

  # stack the panels vertically; the four relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.13, 0.13, 0.13, 0.61))

  print(p)

  cat(' \n \n')

}

1.2.1.5 Differentiated

1.2.1.6 Immunoreactive

1.2.1.7 Mesenchymal

1.2.1.8 Proliferative

1.2.2 Signature

# consensusOV labels present in the consensus table, as plain strings (in
# order of first appearance in `consOV_tbl`).
tcga_subtypes <- as.character(unique(consOV_tbl$consensusOV))

# One stacked figure per consensusOV subtype, built from the CD45+ samples,
# with the site box switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(tcga_subtypes)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', tcga_subtypes[i],' \n')

  plist <- default_comp_grid_list(
    filter(comp_tbl_consOV, sort_short_x == "CD45+"),
    consensusOV, tcga_subtypes[i], consensusOV, site_box = FALSE)

  # stack the panels vertically; the four relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.13, 0.13, 0.13, 0.61))

  print(p)

  cat(' \n \n')

}

1.2.2.1 Differentiated

1.2.2.2 Immunoreactive

1.2.2.3 Mesenchymal

1.2.2.4 Proliferative

# One stacked figure per consensusOV subtype, built from the CD45- samples,
# with the site box switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(tcga_subtypes)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', tcga_subtypes[i],' \n')

  plist <- default_comp_grid_list(
    filter(comp_tbl_consOV, sort_short_x == "CD45-"),
    consensusOV, tcga_subtypes[i], consensusOV, site_box = FALSE)

  # stack the panels vertically; the four relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.13, 0.13, 0.13, 0.61))

  print(p)

  cat(' \n \n')

}

1.2.2.5 Differentiated

1.2.2.6 Immunoreactive

1.2.2.7 Mesenchymal

1.2.2.8 Proliferative

1.3 cell state

# Super cell types of interest. Each name is used further down as the prefix
# of a "<name>_subtype_compositions.tsv" file and as a name in `comp_list`.
cois <- c("T.super", "Myeloid.super", "Fibroblast.super", "Ovarian.cancer.super")

# Read the sub-cluster composition table of one super cell type and attach
# the sample meta data.
#
#   x              : super cell type name; prefix of the file to read
#   cluster_column : unquoted name of the cluster column in that file
#
# Every sample x cluster combination is made explicit (added rows get n = 0
# and nrel = 0). The per-sample columns that are NA on those added rows
# (`sample_id_x`, `sort_short_x`, `nsum`) are filled with the value from the
# sample's first row after sorting (arrange() places NA last).
# Returns an ungrouped tibble in which `sample_id` is a copy of `sample_id_x`.
read_comp_wrapper <- function(x, cluster_column) {
  cluster_quo <- enquo(cluster_column)
  comp_file <- paste0("/work/shah/uhlitzf/data/SPECTRUM/freeze/v7/", x, "_subtype_compositions.tsv")

  # replace NA entries of a vector by its first element
  fill_from_first <- function(v) ifelse(is.na(v), v[1], v)

  read_tsv(comp_file) %>%
    # sort fraction is derived from the sample label
    mutate(sort_short_x = ifelse(str_detect(sample_id_x, "CD45\\+"), "CD45+", "CD45-")) %>%
    complete(sample, !!cluster_quo, fill = list(n = 0, nrel = 0)) %>%
    left_join(meta_tbl, by = "sample") %>%
    arrange(sample, sample_id_x, !!cluster_quo) %>%
    group_by(sample) %>%
    mutate(
      sample_id_x = fill_from_first(sample_id_x),
      sort_short_x = fill_from_first(sort_short_x),
      nsum = fill_from_first(nsum)
    ) %>%
    ungroup() %>%
    mutate(sample_id = sample_id_x)
}

# Composition tables per super cell type, named by super cell type.
# All but the first entry of `cois` are read with `cluster_label` as cluster
# column; "T.super" is read separately with `cluster_label_sub`.
comp_list <- setNames(
  lapply(cois[-1], function(coi) read_comp_wrapper(coi, cluster_label)),
  cois[-1]
)
comp_list$T.super <- read_comp_wrapper("T.super", cluster_label_sub)


# foo <- comp_list$T.super %>% 
#   mutate(sample_id = paste(tumor_subsite, patient_id_short, therapy, sort_short_x)) %>%
#   mutate(Treg = str_detect(cluster_label_sub, "CD4.T.reg")) %>% 
#   select(sample, sample_id, cluster_label_sub, n, nrel, nsum, Treg) %>% 
#   group_by(Treg, sample, sample_id) %>% 
#   mutate(n_Treg = ifelse(Treg, sum(n), NA), nrel_Treg_T = ifelse(Treg, sum(nrel), NA)) %>% 
#   na.omit %>% 
#   arrange(desc(nrel_Treg_T)) %>% 
#   mutate(nrel = signif(nrel, 2), nrel_Treg_T = signif(nrel_Treg_T, 2)) %>% 
#   ungroup %>% 
#   distinct(sample, sample_id, n_Treg, nrel_Treg_T) %>% 
#   left_join(comp_tbl_sample %>% select(n_cd45p = n, sample_id), by = "sample_id") %>% 
#   group_by(sample, sample_id, n_Treg, nrel_Treg_T) %>% 
#   summarise(n_cd45p = sum(n_cd45p)) %>% 
#   mutate(nrel_Treg_cd45p = signif(n_Treg/n_cd45p*100, 2)) %>% 
#   ungroup %>% 
#   arrange(-nrel_Treg_cd45p) %>% 
#   select(-sample_id) %>% 
#   mutate(sample_mpif = str_remove_all(sample, "_CD45P_")) %>% 
#   left_join(select(mpif_meta_tbl, sample_mpif = sample_id, has_mpif = batch), by = "sample_mpif") %>% 
#   mutate(has_mpif = ifelse(is.na(has_mpif), F, T)) %>% 
#   select(-sample_mpif)
# 
# ggplot(foo, aes(nrel_Treg_T, nrel_Treg_cd45p, color = has_mpif)) +
#   geom_point() +
#   theme(aspect.ratio = 1)
# 
# write_tsv(foo, "treg_fractions.tsv")
# 
# mpif_meta_tbl$sample_id
# 
# foo %>% 
#   filter(has_mpif == T) %>% 
#   filter(str_detect(sample, "ADNEXA|OVARY|OMENTUM")) %>% 
#   write_tsv("treg_fractions_sub.tsv")

1.3.1 T/NK cells

# Sorted sub-cluster labels found in the T.super composition table.
T.super.clusters <- sort(unique(comp_list$T.super$cluster_label_sub))

# One stacked figure per sub-cluster, built from the CD45+ samples.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(T.super.clusters)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', T.super.clusters[i],' \n')

  # NOTE(review): the other cell state loops in this file pass `super_type =`;
  # confirm that `super_type_sub` is the intended argument here.
  plist <- default_comp_grid_list(
    filter(comp_list$T.super, sort_short_x == "CD45+"),
    cluster_label_sub, T.super.clusters[i], cluster_label_sub,
    super_type_sub = "T.super")

  # stack the panels vertically; the five relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.14, 0.14, 0.14, 0.14, 0.44))

  print(p)

  cat(' \n \n')

}

1.3.1.1 CD4.T.dysfunc.early

1.3.1.2 CD4.T.dysfunc.late.1

1.3.1.3 CD4.T.dysfunc.late.2

1.3.1.4 CD4.T.effector.memory

1.3.1.5 CD4.T.ISG

1.3.1.6 CD4.T.naive.centr.mem.1

1.3.1.7 CD4.T.naive.centr.mem.2

1.3.1.8 CD4.T.reg.1

1.3.1.9 CD4.T.reg.2

1.3.1.10 CD4.T.reg.3

1.3.1.11 CD4.T.reg.ISG

1.3.1.12 CD4.Th17.1

1.3.1.13 CD4.Th17.2

1.3.1.14 CD4.Th17.3

1.3.1.15 CD8.T.cytotoxic

1.3.1.16 CD8.T.dysfunc.early

1.3.1.17 CD8.T.dysfunc.ISG

1.3.1.18 CD8.T.dysfunc.late

1.3.1.19 CD8.T.effector.memory

1.3.1.20 CD8.T.ISG.early

1.3.1.21 CD8.T.ISG.late

1.3.1.22 CD8.T.naive.centr.mem

1.3.1.23 Cycling.CD4.T

1.3.1.24 Cycling.CD8.T.1

1.3.1.25 Cycling.CD8.T.2

1.3.1.26 Cycling.CD8.T.3

1.3.1.27 Cycling.CD8.T.4

1.3.1.28 Cycling.NK.1

1.3.1.29 Cycling.NK.2

1.3.1.30 Cycling.NK.3

1.3.1.31 gd.T.cell

1.3.1.32 NK.cytotoxic.GZMH

1.3.1.33 NK.cytotoxic.SPON2.1

1.3.1.34 NK.cytotoxic.SPON2.2

1.3.1.35 NK.reg.CCL3

1.3.1.36 NK.reg.CD56

1.3.1.37 NK.reg.CRTAM

1.3.1.38 NK.reg.IGFBP2

1.3.1.39 NK.reg.ISG

1.3.1.40 NK.reg.KRT81.KRT86.1

1.3.1.41 NK.reg.KRT81.KRT86.2

1.3.2 Myeloid cells

# Sorted cluster labels found in the Myeloid.super composition table.
Myeloid.super.clusters <- sort(unique(comp_list$Myeloid.super$cluster_label))

# One stacked figure per cluster, built from the CD45+ samples.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(Myeloid.super.clusters)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', Myeloid.super.clusters[i],' \n')

  plist <- default_comp_grid_list(
    filter(comp_list$Myeloid.super, sort_short_x == "CD45+"),
    cluster_label, Myeloid.super.clusters[i], cluster_label,
    super_type = "Myeloid.super")

  # stack the panels vertically; the five relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.14, 0.14, 0.14, 0.14, 0.44))

  print(p)

  cat(' \n \n')

}

1.3.2.1 cDC1

1.3.2.2 cDC2

1.3.2.3 Clearing.M

1.3.2.4 Cycling.M

1.3.2.5 M1.S100A8

1.3.2.6 M2.CXCL10

1.3.2.7 M2.ECM.1

1.3.2.8 M2.ECM.2

1.3.2.9 M2.MARCO

1.3.2.10 M2.SELENOP

1.3.2.11 Mast.cell

1.3.2.12 mDC

1.3.2.13 pDC

1.3.3 Fibroblasts

# Sorted cluster labels found in the Fibroblast.super composition table.
Fibroblast.super.clusters <- sort(unique(comp_list$Fibroblast.super$cluster_label))

# One stacked figure per cluster, built from the CD45- samples.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(Fibroblast.super.clusters)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', Fibroblast.super.clusters[i],' \n')

  plist <- default_comp_grid_list(
    filter(comp_list$Fibroblast.super, sort_short_x == "CD45-"),
    cluster_label, Fibroblast.super.clusters[i], cluster_label,
    super_type = "Fibroblast.super")

  # stack the panels vertically; the five relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.14, 0.14, 0.14, 0.14, 0.44))

  print(p)

  cat(' \n \n')

}

1.3.3.1 Activated.CAF.IGF1

1.3.3.2 Activated.CAF.ISG

1.3.3.3 Activated.CAF.TGFb

1.3.3.4 Angiogenic.CAF

1.3.3.5 Cycling.CAF

1.3.3.6 Early.CAF.1

1.3.3.7 Early.CAF.2

1.3.3.8 Mesothelial.CAF.IL1

1.3.3.9 Pericyte

1.3.4 Cancer cells

# Sorted cluster labels found in the Ovarian.cancer.super composition table.
Cancer.super.clusters <- sort(unique(comp_list$Ovarian.cancer.super$cluster_label))

# One stacked figure per cluster, built from the CD45- samples.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(Cancer.super.clusters)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', Cancer.super.clusters[i],' \n')

  plist <- default_comp_grid_list(
    filter(comp_list$Ovarian.cancer.super, sort_short_x == "CD45-"),
    cluster_label, Cancer.super.clusters[i], cluster_label,
    super_type = "Ovarian.cancer.super")

  # stack the panels vertically; the five relative heights presumably match
  # the number of panels returned above
  p <- plot_grid(plotlist = plist, ncol = 1, align = "v",
                 rel_heights = c(0.14, 0.14, 0.14, 0.14, 0.44))

  print(p)

  cat(' \n \n')

}

1.3.4.1 Cancer.cell.1

1.3.4.2 Cancer.cell.2

1.3.4.3 Cancer.cell.3

1.3.4.4 Cancer.cell.4

1.3.4.5 Cancer.cell.5

1.3.4.6 Cancer.cell.6

1.3.4.7 Ciliated.cell.1

1.3.4.8 Ciliated.cell.2

1.3.4.9 Cycling.cancer.cell.1

1.3.4.10 Cycling.cancer.cell.2

2 mpIF composition

## meta data for mpIF
## re-source the shared globals, then re-apply the palette relabelling
## (dots -> spaces, "Ovarian" -> "Ov"); presumably needed because sourcing
## re-creates `clrs` -- confirm against _src/global_vars.R
source("_src/global_vars.R")

names(clrs$cell_type) <- names(clrs$cell_type) %>%
  str_replace_all("\\.", " ") %>%
  str_replace_all("Ovarian", "Ov")

# mpif_pixel <- read_tsv("/work/shah/vazquezi/data/transfers/spectrum/results/mpif/v10/integrate/outputs/cohort_merge/patient/SPECTRUM/all/detection.tsv") %>%
#   select(cell_id, compartment = Parent)
#  
# mpif_cell_type <- read_tsv("/work/shah/vazquezi/data/transfers/spectrum/results/mpif/v10/integrate/outputs/cohort_merge/patient/SPECTRUM/all/cell_type_manual.tsv") %>%
#   left_join(mpif_pixel, by = "cell_id") %>%
#   mutate(cell_id = str_remove_all(cell_id, "CD68.TOX.PD1.PDL1.CD8.panCK_CK8-18.DAPI_|_component_data - resolution #1")) %>%
#   separate(cell_id, into = c("patient_id", "tumor_subsite", "fov_id", "cell_idx"),
#            sep = "_", remove = F) %>%
#   mutate(fov_id = paste0(patient_id, "_", tumor_subsite, "_", fov_id),
#          slide_id = paste0(patient_id, "_", tumor_subsite))
# 
# ## expand for double and triple positive cells (cell type markers)
# mpif_ncell <- mpif_cell_type %>%
#   # sample_n(10000) %>%
#   mutate(CD68_state_log = CD68_state == "CD68+",
#          CD8_state_log = CD8_state == "CD8+",
#          panCK_state_log = panCK_state == "panCK+") %>%
#   group_by(cell_id) %>%
#   mutate(n_cells = sum(CD68_state_log, CD8_state_log, panCK_state_log)) %>%
#   select(cell_id, n_cells) %>%
#   deframe
# 
# mpif_cell_type_expanded <- bind_rows(
#   filter(mpif_cell_type, cell_id %in% names(mpif_ncell[mpif_ncell == 0]),
#          TOX_state == "TOX+" | PD1_state == "PD1+" | PDL1_state == "PDL1+") %>%
#     mutate(cell_id = paste0(cell_id, "_0"),
#            cell_type = c("Other")),
#   filter(mpif_cell_type, cell_id %in% names(mpif_ncell[mpif_ncell == 0]),
#          TOX_state == "TOX-" & PD1_state == "PD1-" & PDL1_state == "PDL1-") %>%
#     mutate(cell_id = paste0(cell_id, "_0"),
#            cell_type = c("Unknown")),
#   filter(mpif_cell_type, cell_id %in% names(mpif_ncell[mpif_ncell == 1]),
#          CD8_state == "CD8+") %>%
#     mutate(cell_id = paste0(cell_id, "_1"),
#            cell_type = c("CD8+")),
#   filter(mpif_cell_type, cell_id %in% names(mpif_ncell[mpif_ncell == 1]),
#          CD68_state == "CD68+") %>%
#     mutate(cell_id = paste0(cell_id, "_2"),
#            cell_type = c("CD68+")),
#   filter(mpif_cell_type, cell_id %in% names(mpif_ncell[mpif_ncell == 1]),
#          panCK_state == "panCK+") %>%
#     mutate(cell_id = paste0(cell_id, "_3"),
#            cell_type = c("panCK+")),
#   filter(mpif_cell_type, cell_id %in% names(mpif_ncell[mpif_ncell == 2]),
#          CD8_state == "CD8+") %>%
#     mutate(cell_id = paste0(cell_id, "_1"),
#            cell_type = c("CD8+")),
#   filter(mpif_cell_type, cell_id %in% names(mpif_ncell[mpif_ncell == 2]),
#          CD68_state == "CD68+") %>%
#     mutate(cell_id = paste0(cell_id, "_2"),
#            cell_type = c("CD68+")),
#   filter(mpif_cell_type, cell_id %in% names(mpif_ncell[mpif_ncell == 2]),
#          panCK_state == "panCK+") %>%
#     mutate(cell_id = paste0(cell_id, "_3"),
#            cell_type = c("panCK+")),
#   filter(mpif_cell_type, cell_id %in% names(mpif_ncell[mpif_ncell == 3]),
#          CD8_state == "CD8+") %>%
#     mutate(cell_id = paste0(cell_id, "_1"),
#            cell_type = c("CD8+")),
#   filter(mpif_cell_type, cell_id %in% names(mpif_ncell[mpif_ncell == 3]),
#          CD68_state == "CD68+") %>%
#     mutate(cell_id = paste0(cell_id, "_2"),
#            cell_type = c("CD68+")),
#   filter(mpif_cell_type, cell_id %in% names(mpif_ncell[mpif_ncell == 3]),
#          panCK_state == "panCK+") %>%
#     mutate(cell_id = paste0(cell_id, "_3"),
#            cell_type = c("panCK+"))
# )
# 
# mpif_cell_state_expanded <- mpif_cell_type_expanded %>%
#   mutate(cell_state = case_when(
#     cell_type == "CD8+" ~ paste0(CD8_state, TOX_state, PD1_state),
#     cell_type == "CD68+" ~ paste0(CD68_state, PDL1_state),
#     cell_type == "panCK+" ~ paste0(panCK_state, PDL1_state),
#     cell_type == "Unknown" ~ "Unknown",
#     cell_type == "Other" ~ "Other"
#   ))
# 
# write_tsv(mpif_cell_state_expanded, "/work/shah/uhlitzf/data/SPECTRUM/mpIF/cell_type_manual_expanded_v10.tsv")

# Per-cell mpIF calls (type, state and the individual marker state columns)
# joined to the slide-level meta data. Rows with any NA after the join are
# dropped. `sort_short_x` is a copy of the tissue compartment, and
# `tumor_supersite` becomes an ordered factor following the reversed palette
# order.
mpif_cell_state_expanded <- read_tsv("/work/shah/uhlitzf/data/SPECTRUM/mpIF/cell_type_manual_expanded_v10.tsv") %>%
  select(cell_id, slide_id, fov_id, compartment, cell_type, cell_state, contains("state")) %>%
  left_join(
    select(mpif_meta_tbl,
           slide_id, patient_id_short, tumor_supersite, tumor_subsite,
           therapy, sample_id, consensus_signature, tumor_megasite),
    by = "slide_id"
  ) %>%
  na.omit() %>%
  mutate(
    sort_short_x = compartment,
    tumor_supersite = ordered(tumor_supersite, levels = rev(names(clrs$tumor_supersite)))
  )

## cell state composition fov lvl
## n: cells per cell state within one field of view
## nrel: share of the field of view (proportion between 0 and 1)
mpif_cell_state_n <- mpif_cell_state_expanded %>%
  count(patient_id_short, tumor_supersite, tumor_subsite, tumor_megasite,
        therapy, sample_id, fov_id, consensus_signature,
        cell_state, cell_type) %>%
  group_by(patient_id_short, tumor_supersite, tumor_subsite, tumor_megasite,
           therapy, sample_id, fov_id, consensus_signature) %>%
  mutate(nrel = n/sum(n)) %>%
  ungroup() %>%
  add_helper_columns

## cell state composition slide lvl
## n: cells per cell state within one sample (all fields of view pooled)
## nrel: share of the sample (proportion between 0 and 1)
mpif_cell_state_n_slide <- mpif_cell_state_expanded %>%
  count(patient_id_short, tumor_supersite, tumor_subsite, tumor_megasite,
        therapy, sample_id, consensus_signature,
        cell_state, cell_type) %>%
  group_by(patient_id_short, tumor_supersite, tumor_subsite, tumor_megasite,
           therapy, sample_id, consensus_signature) %>%
  mutate(nrel = n/sum(n)) %>%
  ungroup() %>%
  add_helper_columns

## cell state composition slide lvl compartment
## n: cells per cell state within one sample and compartment (`sort_short_x`)
## nrel: share of that sample/compartment, in percent
mpif_cell_state_n_slide_compartment <- mpif_cell_state_expanded %>%
  count(patient_id_short, tumor_supersite, tumor_subsite, tumor_megasite,
        therapy, sample_id, consensus_signature, sort_short_x,
        cell_state, cell_type) %>%
  group_by(patient_id_short, tumor_supersite, tumor_subsite, tumor_megasite,
           therapy, sample_id, consensus_signature, sort_short_x) %>%
  mutate(nrel = n/sum(n)*100) %>%
  ungroup() %>%
  add_helper_columns

## cell type composition slide lvl compartment
## n: cells per cell type within one sample and compartment (`sort_short_x`)
## nrel: share of that sample/compartment, in percent
mpif_cell_type_n_slide_compartment <- mpif_cell_state_expanded %>%
  count(patient_id_short, tumor_supersite, tumor_subsite, tumor_megasite,
        therapy, sample_id, consensus_signature, sort_short_x,
        cell_type) %>%
  group_by(patient_id_short, tumor_supersite, tumor_subsite, tumor_megasite,
           therapy, sample_id, consensus_signature, sort_short_x) %>%
  mutate(nrel = n/sum(n)*100) %>%
  ungroup() %>%
  add_helper_columns

2.1 cell type

2.1.1 Site

# mpIF cell types, in the order in which the figures are emitted.
cell_types_mpif <- c("CD8+", "CD68+", "panCK+", "Other", "Unknown")

# One figure per mpIF cell type: tumor compartment on the left, stroma on the
# right, each a vertical stack of panels with the mutational signature box
# switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(cell_types_mpif)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', cell_types_mpif[i],' \n')

  plist_t <- default_comp_grid_list(
    filter(mpif_cell_state_n_slide_compartment, sort_short_x == "Tumor"),
    cell_type, cell_types_mpif[i], cell_state, mutsig_box = FALSE,
    nmax = 250000)

  plist_s <- default_comp_grid_list(
    filter(mpif_cell_state_n_slide_compartment, sort_short_x == "Stroma"),
    cell_type, cell_types_mpif[i], cell_state, mutsig_box = FALSE,
    nmax = 250000)

  pt <- plot_grid(plotlist = plist_t, ncol = 1, align = "v",
                  rel_heights = c(0.13, 0.13, 0.13, 0.61))

  ps <- plot_grid(plotlist = plist_s, ncol = 1, align = "v",
                  rel_heights = c(0.13, 0.13, 0.13, 0.61))

  # tumor and stroma stacks side by side
  p <- plot_grid(pt, ps, ncol = 2)

  print(p)

  cat(' \n \n')

}

2.1.1.1 CD8+

2.1.1.2 CD68+

2.1.1.3 panCK+

2.1.1.4 Other

2.1.1.5 Unknown

2.1.2 Signature

# mpIF cell types, in the order in which the figures are emitted.
cell_types_mpif <- c("CD8+", "CD68+", "panCK+", "Other", "Unknown")

# One figure per mpIF cell type: tumor compartment on the left, stroma on the
# right, each a vertical stack of panels with the site box switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(cell_types_mpif)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', cell_types_mpif[i],' \n')

  plist_t <- default_comp_grid_list(
    filter(mpif_cell_state_n_slide_compartment, sort_short_x == "Tumor"),
    cell_type, cell_types_mpif[i], cell_state, site_box = FALSE,
    nmax = 250000)

  plist_s <- default_comp_grid_list(
    filter(mpif_cell_state_n_slide_compartment, sort_short_x == "Stroma"),
    cell_type, cell_types_mpif[i], cell_state, site_box = FALSE,
    nmax = 250000)

  pt <- plot_grid(plotlist = plist_t, ncol = 1, align = "v",
                  rel_heights = c(0.13, 0.13, 0.13, 0.61))

  ps <- plot_grid(plotlist = plist_s, ncol = 1, align = "v",
                  rel_heights = c(0.13, 0.13, 0.13, 0.61))

  # tumor and stroma stacks side by side
  p <- plot_grid(pt, ps, ncol = 2)

  print(p)

  cat(' \n \n')

}

2.1.2.1 CD8+

2.1.2.2 CD68+

2.1.2.3 panCK+

2.1.2.4 Other

2.1.2.5 Unknown

2.2 cell state

2.2.1 Site

# mpIF cell states, taken from the names of the cell state palette.
cell_states_mpif <- names(clrs$cell_state)

# One figure per mpIF cell state: tumor compartment on the left, stroma on
# the right, each a vertical stack of panels with the mutational signature
# box switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(cell_states_mpif)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', cell_states_mpif[i],' \n')

  plist_t <- default_comp_grid_list(
    filter(mpif_cell_state_n_slide_compartment, sort_short_x == "Tumor"),
    cell_state, cell_states_mpif[i], cell_state, mutsig_box = FALSE,
    nmax = 250000)

  plist_s <- default_comp_grid_list(
    filter(mpif_cell_state_n_slide_compartment, sort_short_x == "Stroma"),
    cell_state, cell_states_mpif[i], cell_state, mutsig_box = FALSE,
    nmax = 250000)

  pt <- plot_grid(plotlist = plist_t, ncol = 1, align = "v",
                  rel_heights = c(0.13, 0.13, 0.13, 0.61))

  ps <- plot_grid(plotlist = plist_s, ncol = 1, align = "v",
                  rel_heights = c(0.13, 0.13, 0.13, 0.61))

  # tumor and stroma stacks side by side
  p <- plot_grid(pt, ps, ncol = 2)

  print(p)

  cat(' \n \n')

}

2.2.1.1 CD8+TOX-PD1-

2.2.1.2 CD8+TOX-PD1+

2.2.1.3 CD8+TOX+PD1-

2.2.1.4 CD8+TOX+PD1+

2.2.1.5 CD68+PDL1-

2.2.1.6 CD68+PDL1+

2.2.1.7 panCK+PDL1-

2.2.1.8 panCK+PDL1+

2.2.1.9 Other

2.2.1.10 Unknown

2.2.2 Signature

# mpIF cell states, taken from the names of the cell state palette.
cell_states_mpif <- names(clrs$cell_state)

# One figure per mpIF cell state: tumor compartment on the left, stroma on
# the right, each a vertical stack of panels with the site box switched off.
# seq_along() yields an empty sequence for an empty vector, whereas
# 1:length(x) would iterate over c(1, 0).
for (i in seq_along(cell_states_mpif)) {

  # markdown heading for this figure (only rendered as a heading if the chunk
  # output is emitted as-is -- confirm chunk options)
  cat('#### ', cell_states_mpif[i],' \n')

  plist_t <- default_comp_grid_list(
    filter(mpif_cell_state_n_slide_compartment, sort_short_x == "Tumor"),
    cell_state, cell_states_mpif[i], cell_state, site_box = FALSE,
    nmax = 250000)

  plist_s <- default_comp_grid_list(
    filter(mpif_cell_state_n_slide_compartment, sort_short_x == "Stroma"),
    cell_state, cell_states_mpif[i], cell_state, site_box = FALSE,
    nmax = 250000)

  pt <- plot_grid(plotlist = plist_t, ncol = 1, align = "v",
                  rel_heights = c(0.13, 0.13, 0.13, 0.61))

  ps <- plot_grid(plotlist = plist_s, ncol = 1, align = "v",
                  rel_heights = c(0.13, 0.13, 0.13, 0.61))

  # tumor and stroma stacks side by side
  p <- plot_grid(pt, ps, ncol = 2)

  print(p)

  cat(' \n \n')

}

2.2.2.1 CD8+TOX-PD1-

2.2.2.2 CD8+TOX-PD1+

2.2.2.3 CD8+TOX+PD1-

2.2.2.4 CD8+TOX+PD1+

2.2.2.5 CD68+PDL1-

2.2.2.6 CD68+PDL1+

2.2.2.7 panCK+PDL1-

2.2.2.8 panCK+PDL1+

2.2.2.9 Other

2.2.2.10 Unknown

3 mpIF x scRNA correlation

# seu_cohort <- read_rds("/work/shah/isabl_data_lake/analyses/68/75/6875/RNASCP/outs/integrated_seurat_final.rds")
# 
# scrna_markers <- as_tibble(cbind(cell_id = colnames(seu_cohort), FetchData(seu_cohort, c("cell_type", "sample", "CD68", "PDCD1", "CD274", "TOX", "CD8A", "CD8B", "KRT8", "KRT19")))) %>%
#   mutate(CD68_state = ifelse(CD68 > 0, "CD68+", "CD68-"),
#          CD8_state = ifelse(CD8A > 0 | CD8B > 0, "CD8+", "CD8-"),
#          panCK_state = ifelse(KRT8 > 0 | KRT19 > 0, "panCK+", "panCK-"),
#          TOX_state = ifelse(TOX > 0, "TOX+", "TOX-"),
#          PD1_state = ifelse(PDCD1 > 0, "PD1+", "PD1-"),
#          PDL1_state = ifelse(CD274 > 0, "PDL1+", "PDL1-"))
# 
# write_tsv(scrna_markers, "/work/shah/uhlitzf/data/SPECTRUM/freeze/v7/scrna_mpif_marker_expression.tsv")

# Per-cell marker calls from scRNA, mapped onto mpIF-style cell types/states.
# Cells positive for more than one of CD68 / CD8 / panCK are discarded; the
# remaining cells get the single positive marker as `cell_type_sc` ("Other"
# if none is positive) and a `cell_state` pasted together from the marker
# state columns.
scrna_markers <- read_tsv("/work/shah/uhlitzf/data/SPECTRUM/freeze/v7/scrna_mpif_marker_expression.tsv") %>%
  ## remove double positives
  filter(
    !(CD68_state == "CD68+" & CD8_state == "CD8+"),
    !(CD68_state == "CD68+" & panCK_state == "panCK+"),
    !(CD8_state == "CD8+" & panCK_state == "panCK+")
  ) %>%
  mutate(
    cell_type_sc = ifelse(
      CD68_state == "CD68+",
      "CD68+",
      ifelse(
        CD8_state == "CD8+",
        "CD8+",
        ifelse(panCK_state == "panCK+", "panCK+", "Other")
      )
    )
  ) %>%
  mutate(
    cell_state = case_when(
      cell_type_sc == "CD8+" ~ paste0(CD8_state, TOX_state, PD1_state),
      cell_type_sc == "CD68+" ~ paste0(CD68_state, PDL1_state),
      cell_type_sc == "panCK+" ~ paste0(panCK_state, PDL1_state),
      # never matched: the assignment above does not produce "Unknown"
      cell_type_sc == "Unknown" ~ "Unknown",
      cell_type_sc == "Other" ~ "Other"
    )
  )

# Marker-defined cell type composition per scRNA sample.
#   n    : cells per `cell_type_sc` within a sample
#   nrel : percentage of the sample's marker-annotated cells
# Cells without a match in `scrna_markers` (or with any other NA) are dropped.
markers_pos_frac_scrna_celltype <- seu_tbl_full %>%
  select(cell_id, sample, tumor_supersite, consensus_signature,
         patient_id_short, sort_short_x) %>%
  left_join(
    select(scrna_markers, cell_id, cell_type_sc, cell_state),
    by = "cell_id"
  ) %>%
  na.omit() %>%
  count(sample, cell_type_sc, tumor_supersite, consensus_signature,
        patient_id_short, sort_short_x) %>%
  group_by(sample, tumor_supersite, consensus_signature,
           patient_id_short, sort_short_x) %>%
  mutate(nrel = n/sum(n)*100) %>%
  select(sample, cell_type_sc, n, nrel, everything()) %>%
  ungroup()

# Marker-defined cell state composition per scRNA sample.
#   n    : cells per `cell_state` within a sample
#   nrel : percentage among the sample's cells of the same `cell_type_sc`
# Cells without a match in `scrna_markers` (or with any other NA) are dropped.
markers_pos_frac_scrna_cellstate <- seu_tbl_full %>%
  select(cell_id, sample, tumor_supersite, consensus_signature,
         patient_id_short, sort_short_x) %>%
  left_join(
    select(scrna_markers, cell_id, cell_type_sc, cell_state),
    by = "cell_id"
  ) %>%
  na.omit() %>%
  count(sample, cell_state, cell_type_sc, tumor_supersite,
        consensus_signature, patient_id_short, sort_short_x) %>%
  group_by(sample, cell_type_sc, tumor_supersite, consensus_signature,
           patient_id_short, sort_short_x) %>%
  mutate(nrel = n/sum(n)*100) %>%
  select(sample, cell_type_sc, cell_state, n, nrel, everything()) %>%
  ungroup()

# Per sample and gene: number (`n`) and percentage (`nrel`) of cells with a
# value above zero for that gene. One row per gene x sample after `distinct`.
markers_pos_frac_scrna_gene <- seu_tbl_full %>%
  # select(cell_id, sample) %>%
  select(cell_id, sample, cell_type, tumor_supersite, consensus_signature, patient_id_short, sort_short_x) %>%
  # attach the marker table; its own `sample` / `cell_type` columns are
  # dropped so they do not clash with the ones selected above
  left_join(select(scrna_markers, -sample, -cell_type), by = "cell_id") %>%
  na.omit() %>%
  # drop every column whose name contains "state", leaving the id/meta
  # columns, `cell_type_sc` and the per-gene value columns
  select(-contains("state")) %>%
  mutate(cell_type = str_replace_all(cell_type, "\\.", " ")) %>%
  # long format: all columns not excluded here become gene/value pairs
  gather(gene, value, -cell_id, -sample, -cell_type, -tumor_supersite,
         -consensus_signature, -patient_id_short, -sort_short_x, -cell_type_sc) %>%
  mutate(gene_state = value > 0) %>%
  group_by(gene, sample, tumor_supersite, consensus_signature, patient_id_short, sort_short_x) %>%
  # `n` is assigned first (positive cells in the group) and then reused for
  # `nrel`; `length(n)` presumably equals the group size because the scalar
  # is recycled to the group before the next expression runs -- TODO confirm
  # on the dplyr version in use
  mutate(n = sum(gene_state),
         nrel = n/length(n)*100) %>%
  ungroup %>%
  distinct(gene, sample, tumor_supersite, consensus_signature, patient_id_short, sort_short_x, n, nrel) %>%
  select(sample, gene, n, nrel, everything())

3.1 CD45+ cell type

# mpIF vs scRNA cell type percentages, CD45+ fraction.
# mpIF counts/percentages are summed over cell states per sample, cell type
# and compartment; the mpIF sample id is rewritten ("_S1" -> "_S1_CD45P_") so
# that it can be joined to the scRNA `sample`. Only CD8+ and CD68+ cells in
# the stroma/tumor compartments with a CD45+ scRNA counterpart are kept.
markers_pos_scrna_mpif <- mpif_cell_state_n_slide_compartment %>%
  transmute(
    sample = str_replace_all(sample_id, "_S1", "_S1_CD45P_"),
    cell_type_sc = cell_type,
    compartment = sort_short_x,
    n_mpif = n,
    nrel_mpif = nrel
  ) %>%
  group_by(sample, cell_type_sc, compartment) %>%
  summarise(n_mpif = sum(n_mpif), nrel_mpif = sum(nrel_mpif)) %>%
  ungroup() %>%
  left_join(markers_pos_frac_scrna_celltype, by = c("sample", "cell_type_sc")) %>%
  filter(
    sort_short_x == "CD45+",
    cell_type_sc %in% c("CD8+", "CD68+"),
    compartment %in% c("Stroma", "Tumor")
  )

# Layers shared by both scatter plots below: one facet per cell type x
# compartment with free axes, a linear fit, a Spearman correlation label,
# axis titles and square panels. They are added after geom_point(), so the
# fit and label are drawn on top of the points.
# Note: the plotted `nrel` / `nrel_mpif` columns are percentages, although
# the axis titles say "Fraction".
common_layers <- list(
  facet_wrap(cell_type_sc~compartment, scales = "free"),
  stat_smooth(aes(nrel, nrel_mpif), method = "lm", color = "black"),
  stat_cor(aes(nrel, nrel_mpif), method = "spearman", color = "black"),
  labs(x = "Fraction in scRNA (CD45+)",
       y = "Fraction in mpIF"),
  # coord_cartesian(ylim  = c(0, 100), xlim = c(0, 100)),
  theme(aspect.ratio = 1)
)

# scRNA vs mpIF, points colored by tumor supersite
p1 <- ggplot(markers_pos_scrna_mpif) +
  geom_point(aes(nrel, nrel_mpif, color = tumor_supersite)) +
  scale_color_manual(values = clrs$tumor_supersite) +
  common_layers

# same data, points colored by consensus mutational signature
p2 <- ggplot(markers_pos_scrna_mpif) +
  geom_point(aes(nrel, nrel_mpif, color = consensus_signature)) +
  scale_color_manual(values = clrs$consensus_signature) +
  common_layers

# both variants side by side
plot_grid(p1, p2, ncol = 2)

3.2 CD45+ cell state

# mpIF vs scRNA cell state percentages, CD45+ fraction.
# `nrel_mpif` is the percentage of a cell state among the cells of the same
# cell type within a sample and compartment. The mpIF sample id is rewritten
# ("_S1" -> "_S1_CD45P_") so that it can be joined to the scRNA `sample`.
# Only four CD8 / CD68 states in the stroma/tumor compartments with a CD45+
# scRNA counterpart are kept.
markers_pos_scrna_mpif_state <- mpif_cell_state_n_slide_compartment %>%
  group_by(sample_id, sort_short_x, cell_type) %>%
  mutate(nrel_state = n/sum(n)*100) %>%
  # The grouping is only needed for `nrel_state`. Dropping it here avoids
  # renaming and mutating grouping columns while grouped and returns an
  # ungrouped table.
  ungroup() %>%
  select(sample = sample_id, cell_type_sc = cell_type,
         compartment = sort_short_x,
         n_mpif = n, nrel_mpif = nrel_state, cell_state) %>%
  mutate(sample = str_replace_all(sample, "_S1", "_S1_CD45P_")) %>%
  left_join(markers_pos_frac_scrna_cellstate, by = c("sample", "cell_type_sc", "cell_state")) %>%
  # `sort_short_x` now refers to the scRNA sort fraction brought in by the
  # join; the mpIF column of that name was renamed to `compartment` above
  filter(sort_short_x == "CD45+",
         cell_state %in% c("CD8+TOX-PD1-", "CD8+TOX+PD1+", "CD68+PDL1-", "CD68+PDL1+"),
         compartment %in% c("Stroma", "Tumor"))

## layers shared by both scatter plots: one panel per cell state and
## compartment (four columns), a linear fit and the Spearman correlation
common_layers <- list(
  facet_wrap(cell_state ~ compartment, scales = "free", ncol = 4),
  geom_smooth(mapping = aes(x = nrel, y = nrel_mpif),
              method = "lm", color = "black"),
  stat_cor(mapping = aes(x = nrel, y = nrel_mpif),
           method = "spearman", color = "black"),
  labs(x = "Fraction in scRNA (CD45+)", y = "Fraction in mpIF"),
  # coord_cartesian(ylim  = c(0, 100), xlim = c(0, 100)),
  theme(aspect.ratio = 1)
)

## points colored by tumor site
p1 <- ggplot(data = markers_pos_scrna_mpif_state) +
  geom_point(mapping = aes(x = nrel, y = nrel_mpif,
                           color = tumor_supersite)) +
  scale_color_manual(values = clrs$tumor_supersite) +
  common_layers

## points colored by consensus mutational signature
p2 <- ggplot(data = markers_pos_scrna_mpif_state) +
  geom_point(mapping = aes(x = nrel, y = nrel_mpif,
                           color = consensus_signature)) +
  scale_color_manual(values = clrs$consensus_signature) +
  common_layers

plot_grid(plotlist = list(p1, p2), ncol = 1)

3.3 CD45- cell type

## mpIF cell type counts/fractions per slide and compartment, summed over
## cell states and joined to the scRNA cell type fractions.
## The "_S1" -> "_S1_CD45N_" rewrite presumably maps mpIF slide ids onto the
## CD45- scRNA sample ids -- TODO confirm against the sample sheet.
markers_pos_scrna_mpif_cd45n <- mpif_cell_state_n_slide_compartment %>%
  select(
    sample = sample_id,
    cell_type_sc = cell_type,
    compartment = sort_short_x,
    n_mpif = n,
    nrel_mpif = nrel
  ) %>%
  mutate(sample = str_replace_all(sample, "_S1", "_S1_CD45N_")) %>%
  group_by(sample, cell_type_sc, compartment) %>%
  summarise(n_mpif = sum(n_mpif), nrel_mpif = sum(nrel_mpif)) %>%
  ungroup() %>%
  left_join(markers_pos_frac_scrna_celltype,
            by = c("sample", "cell_type_sc")) %>%
  ## sort_short_x now refers to the column brought in by the join, because
  ## the mpIF column of that name was renamed to `compartment` above
  filter(sort_short_x == "CD45-") %>%
  filter(cell_type_sc %in% c("panCK+")) %>%
  filter(compartment %in% c("Stroma", "Tumor"))

## layers shared by both scatter plots: one panel per cell type and
## compartment, a linear fit and the Pearson correlation
## NOTE(review): the CD45+ sections use Spearman here -- confirm that the
## switch to Pearson is intended
common_layers <- list(
  facet_wrap(cell_type_sc ~ compartment, scales = "free"),
  stat_smooth(mapping = aes(x = nrel, y = nrel_mpif),
              method = "lm", color = "black"),
  stat_cor(mapping = aes(x = nrel, y = nrel_mpif),
           method = "pearson", color = "black"),
  labs(x = "Fraction in scRNA (CD45-)", y = "Fraction in mpIF"),
  # coord_cartesian(ylim  = c(0, 100), xlim = c(0, 100)),
  theme(aspect.ratio = 1)
)

## points colored by tumor site
p1 <- ggplot(data = markers_pos_scrna_mpif_cd45n) +
  geom_point(mapping = aes(x = nrel, y = nrel_mpif,
                           color = tumor_supersite)) +
  scale_color_manual(values = clrs$tumor_supersite) +
  common_layers

## points colored by consensus mutational signature
p2 <- ggplot(data = markers_pos_scrna_mpif_cd45n) +
  geom_point(mapping = aes(x = nrel, y = nrel_mpif,
                           color = consensus_signature)) +
  scale_color_manual(values = clrs$consensus_signature) +
  common_layers

plot_grid(plotlist = list(p1, p2))

3.4 CD45- cell state

## mpIF cell state fractions (percent of each state within its
## slide x compartment x cell type group), joined to the scRNA cell state
## fractions of the CD45- samples.
markers_pos_scrna_mpif_state_cd45n <- mpif_cell_state_n_slide_compartment %>%
  group_by(sample_id, sort_short_x, cell_type) %>%
  mutate(nrel_state = n/sum(n)*100) %>%
  ## drop the grouping right after the per-group normalisation so that the
  ## renaming, id rewrite, join and filter below run on a plain tibble
  ungroup() %>%
  select(sample = sample_id, cell_type_sc = cell_type,
         compartment = sort_short_x, n_mpif = n, 
         nrel_mpif = nrel_state, cell_state) %>%
  ## presumably maps mpIF slide ids onto CD45- scRNA sample ids -- confirm
  mutate(sample = str_replace_all(sample, "_S1", "_S1_CD45N_")) %>% 
  left_join(markers_pos_frac_scrna_cellstate, by = c("sample", "cell_type_sc", "cell_state")) %>%
  filter(sort_short_x == "CD45-",
         cell_state %in% c("panCK+PDL1-", "panCK+PDL1+"),
         compartment %in% c("Stroma", "Tumor"))

## layers shared by both scatter plots: one panel per cell state and
## compartment (four columns), a linear fit and the Pearson correlation
common_layers <- list(
  facet_wrap(cell_state ~ compartment, scales = "free", ncol = 4),
  geom_smooth(mapping = aes(x = nrel, y = nrel_mpif),
              method = "lm", color = "black"),
  stat_cor(mapping = aes(x = nrel, y = nrel_mpif),
           method = "pearson", color = "black"),
  labs(x = "Fraction in scRNA (CD45-)", y = "Fraction in mpIF"),
  # coord_cartesian(ylim  = c(0, 100), xlim = c(0, 100)),
  theme(aspect.ratio = 1)
)

## points colored by tumor site
p1 <- ggplot(data = markers_pos_scrna_mpif_state_cd45n) +
  geom_point(mapping = aes(x = nrel, y = nrel_mpif,
                           color = tumor_supersite)) +
  scale_color_manual(values = clrs$tumor_supersite) +
  common_layers

## points colored by consensus mutational signature
p2 <- ggplot(data = markers_pos_scrna_mpif_state_cd45n) +
  geom_point(mapping = aes(x = nrel, y = nrel_mpif,
                           color = consensus_signature)) +
  scale_color_manual(values = clrs$consensus_signature) +
  common_layers

plot_grid(plotlist = list(p1, p2), ncol = 1)

4 Main text figure

## composition panel lists for the main figure; every list gets a trailing
## blank canvas ("empty") that is used as a spacer when the panels are stacked

## scRNA, CD45- fraction, ranked by ovarian cancer cells
plist1 <- default_comp_grid_list(
  filter(comp_tbl_sample, sort_short_x == "CD45-"),
  cell_type, "Ov cancer cell", cell_type
)
plist1[["empty"]] <- ggdraw()

## scRNA, CD45+ fraction, ranked by T cells
plist2 <- default_comp_grid_list(
  filter(comp_tbl_sample, sort_short_x == "CD45+"),
  cell_type, "T cell", cell_type
)
plist2[["empty"]] <- ggdraw()

## mpIF cell states, tumor compartment, ranked by CD8+
plist3 <- default_comp_grid_list(
  filter(mpif_cell_state_n_slide_compartment, sort_short_x == "Tumor"),
  cell_type, "CD8+", cell_state,
  nmax = 250000
)
plist3[["empty"]] <- ggdraw()

## mpIF cell states, stroma compartment, ranked by CD8+
plist4 <- default_comp_grid_list(
  filter(mpif_cell_state_n_slide_compartment, sort_short_x == "Stroma"),
  cell_type, "CD8+", cell_state,
  nmax = 250000
)
plist4[["empty"]] <- ggdraw()

## mpIF cell types, tumor compartment, ranked by CD8+
plist5 <- default_comp_grid_list(
  filter(mpif_cell_type_n_slide_compartment, sort_short_x == "Tumor"),
  cell_type, "CD8+", cell_type,
  nmax = 250000
)
plist5[["empty"]] <- ggdraw()

## mpIF cell types, stroma compartment, ranked by CD8+
plist6 <- default_comp_grid_list(
  filter(mpif_cell_type_n_slide_compartment, sort_short_x == "Stroma"),
  cell_type, "CD8+", cell_type,
  nmax = 250000
)
plist6[["empty"]] <- ggdraw()

## Each pcomp* is one vertically aligned column of panels; each gcomp* puts
## two scRNA columns and two mpIF columns side by side with a narrow blank
## gap in the middle. pcomp1..pcomp4 are scratch names that are overwritten
## for every grid.

## cell type grid incl mut sig
pcomp1 <- plot_grid(
  plotlist = plist1, align = "v", ncol = 1,
  rel_heights = c(0.11, 0.11, 0.13, 0.13, 0.52, 0)
)
pcomp2 <- plot_grid(
  plotlist = plist2, align = "v", ncol = 1,
  rel_heights = c(0.11, 0.11, 0.13, 0.13, 0.52, 0)
)
pcomp3 <- plot_grid(
  plotlist = plist5, align = "v", ncol = 1,
  rel_heights = c(0.11, 0.11, 0.13, 0.13, 0.27, 0.25)
)
pcomp4 <- plot_grid(
  plotlist = plist6, align = "v", ncol = 1,
  rel_heights = c(0.11, 0.11, 0.13, 0.13, 0.27, 0.25)
)
gcomp1 <- plot_grid(
  pcomp1, pcomp2, ggdraw(), pcomp3, pcomp4,
  ncol = 5,
  rel_widths = c(0.245, 0.245, 0.02, 0.245, 0.245)
)

## cell state grid without mutsig (third panel of every list is dropped)
pcomp1 <- plot_grid(
  plotlist = plist1[-3], align = "v", ncol = 1,
  rel_heights = c(0.13, 0.13, 0.16, 0.58, 0)
)
pcomp2 <- plot_grid(
  plotlist = plist2[-3], align = "v", ncol = 1,
  rel_heights = c(0.13, 0.13, 0.16, 0.58, 0)
)
pcomp3 <- plot_grid(
  plotlist = plist3[-3], align = "v", ncol = 1,
  rel_heights = c(0.13, 0.13, 0.16, 0.32, 0.26)
)
pcomp4 <- plot_grid(
  plotlist = plist4[-3], align = "v", ncol = 1,
  rel_heights = c(0.13, 0.13, 0.16, 0.32, 0.26)
)
gcomp2 <- plot_grid(
  pcomp1, pcomp2, ggdraw(), pcomp3, pcomp4,
  ncol = 5,
  rel_widths = c(0.245, 0.245, 0.02, 0.245, 0.245)
)

## cell type grid without mutsig (third panel of every list is dropped)
pcomp1 <- plot_grid(
  plotlist = plist1[-3], align = "v", ncol = 1,
  rel_heights = c(0.13, 0.13, 0.16, 0.58, 0)
)
pcomp2 <- plot_grid(
  plotlist = plist2[-3], align = "v", ncol = 1,
  rel_heights = c(0.13, 0.13, 0.16, 0.58, 0)
)
pcomp3 <- plot_grid(
  plotlist = plist5[-3], align = "v", ncol = 1,
  rel_heights = c(0.13, 0.13, 0.16, 0.32, 0.26)
)
pcomp4 <- plot_grid(
  plotlist = plist6[-3], align = "v", ncol = 1,
  rel_heights = c(0.13, 0.13, 0.16, 0.32, 0.26)
)
gcomp3 <- plot_grid(
  pcomp1, pcomp2, ggdraw(), pcomp3, pcomp4,
  ncol = 5,
  rel_widths = c(0.245, 0.245, 0.02, 0.245, 0.245)
)


## legend for the mpIF cell state fill colors, extracted from a throwaway
## bar plot of the tumor compartment
mpif_state_legend <- get_legend(
  mpif_cell_state_n_slide_compartment %>%
    filter(sort_short_x == "Tumor") %>%
    rank_by(cell_type, "CD8+", cell_state) %>%
    plot_comp_bar(sample_id_lvl, nrel, cell_state,
                  facet = "sort_short_x") +
    scale_color_manual(
      values = c(Adnexa = "#ff0000", Other = "#56B4E9"),
      labels = c("Enriched", "Depleted", "")
    ) +
    labs(fill = "mpIF cell state") +
    guides(fill = guide_legend(ncol = 2))
)

## legend for the mpIF cell type fill colors, extracted from a throwaway
## bar plot of the tumor compartment
mpif_cell_type_legend <- get_legend(plot_comp_bar(
  rank_by(filter(mpif_cell_type_n_slide_compartment, sort_short_x == "Tumor"),
          cell_type, "CD8+", cell_type),
  sample_id_lvl, nrel, cell_type, facet = "sort_short_x") +
  scale_color_manual(values = c(Adnexa = "#ff0000", Other = "#56B4E9"), labels = c("Enriched", "Depleted", "")) + 
  ## the fill aesthetic encodes cell types here, so title the legend
  ## accordingly (was a copy of the cell state legend title)
  labs(fill = "mpIF cell type") +
  guides(fill = guide_legend(ncol = 1)))

## throwaway vector plot (CD45+ samples ranked by T cells) that is only used
## to harvest its shape and color legends
vec_legend_helper <- plot_comp_vector(
  rank_by(filter(comp_tbl_sample, sort_short_x == "CD45+"),
          cell_type, "T cell", cell_type),
  sample_id_rank, patient_id_short,
  tumor_megasite, tumor_megasite, "Adnexa",
  cell_type, "T cell") +
  scale_color_manual(values = c(Adnexa = "#ff0000", Other = "#56B4E9"), labels = c("Enriched", "Depleted", "")) + 
  labs(color = "Non-adnexal\ninfiltration")

## suppress one guide at a time so each get_legend() call returns a single
## legend; "none" replaces the deprecated (and reassignable) logical `F`
vec_legend_shape <- get_legend(vec_legend_helper + guides(color = "none"))
vec_legend_color <- get_legend(vec_legend_helper + guides(shape = "none"))
## Main-figure correlation data: positive cell states only.
## Both tables hold mpIF cell state fractions (percent of each state within
## its slide x compartment x cell type group) joined to the scRNA cell state
## fractions. Grouping is dropped right after the normalisation so that the
## later mutate() calls do not rewrite grouping columns.

## CD45+ fraction: CD8+TOX+PD1+ and CD68+PDL1+
markers_pos_scrna_mpif_state_cd45p <- mpif_cell_state_n_slide_compartment %>%
  group_by(sample_id, sort_short_x, cell_type) %>%
  mutate(nrel_state = n/sum(n)*100) %>%
  ungroup() %>%
  select(sample = sample_id, cell_type_sc = cell_type,
         compartment = sort_short_x, 
         n_mpif = n, nrel_mpif = nrel_state, cell_state) %>%
  mutate(sample = str_replace_all(sample, "_S1", "_S1_CD45P_")) %>% 
  left_join(markers_pos_frac_scrna_cellstate, by = c("sample", "cell_type_sc", "cell_state")) %>%
  filter(sort_short_x == "CD45+",
         cell_state %in% c("CD8+TOX+PD1+", "CD68+PDL1+"),
         compartment %in% c("Stroma", "Tumor"))

## CD45- fraction: panCK+PDL1+
markers_pos_scrna_mpif_state_cd45n <- mpif_cell_state_n_slide_compartment %>%
  group_by(sample_id, sort_short_x, cell_type) %>%
  mutate(nrel_state = n/sum(n)*100) %>%
  ungroup() %>%
  select(sample = sample_id, cell_type_sc = cell_type,
         compartment = sort_short_x, n_mpif = n, 
         nrel_mpif = nrel_state, cell_state) %>%
  mutate(sample = str_replace_all(sample, "_S1", "_S1_CD45N_")) %>% 
  left_join(markers_pos_frac_scrna_cellstate, by = c("sample", "cell_type_sc", "cell_state")) %>%
  filter(sort_short_x == "CD45-",
         cell_state %in% c("panCK+PDL1+"),
         compartment %in% c("Stroma", "Tumor"))

## stack both fractions, put Tumor before Stroma and turn the state names
## into two-line "state / parent cell type" facet labels
markers_pos_scrna_mpif_state <- bind_rows(markers_pos_scrna_mpif_state_cd45p, 
                                          markers_pos_scrna_mpif_state_cd45n) %>% 
  mutate(compartment = ordered(compartment, levels = c("Tumor", "Stroma"))) %>% 
  mutate(cell_state = case_when(
    cell_state == "panCK+PDL1+" ~ "panCK+PDL1+\npanCK+",
    cell_state == "CD8+TOX+PD1+" ~ "CD8+TOX+PD1+\nCD8+",
    cell_state == "CD68+PDL1+" ~ "CD68+PDL1+\nCD68+"
  ))

## layers shared by the three correlation rows: state x compartment facets,
## a linear fit, the Spearman correlation and no color legend
common_layers <- list(
  facet_grid(cell_state~compartment, scales = "free"),
  geom_smooth(aes(nrel, nrel_mpif), method = "lm", color = "black"),
  stat_cor(aes(nrel, nrel_mpif), method = "spearman", color = "black"),
  labs(x = "Fraction in scRNA [%]",
       y = "Fraction in mpIF [%]"),
  # coord_cartesian(ylim  = c(0, 100), xlim = c(0, 100)),
  theme(aspect.ratio = 1, plot.margin = margin(0, 0, 0, 0)),
  ## "none" replaces the deprecated (and reassignable) logical `F`
  guides(color = "none")
)

## one row per cell state; the rows are stacked below, so the column strips
## are hidden everywhere, the y title is kept only on the middle row and the
## x title only on the bottom row
cor_plot_list <- list()

cor_plot_list$p1 <- ggplot(filter(markers_pos_scrna_mpif_state, 
                                  cell_state %in% c("panCK+PDL1+\npanCK+"))) +
  geom_point(aes(nrel, nrel_mpif, color = cell_type_sc)) +
  scale_color_manual(values = clrs$cell_type) +
  common_layers + 
  theme(axis.title.x = element_blank(),
        axis.title.y = element_blank(),
        strip.text.x = element_blank())

cor_plot_list$p2 <- ggplot(filter(markers_pos_scrna_mpif_state, 
                                  cell_state %in% c("CD68+PDL1+\nCD68+"))) +
  geom_point(aes(nrel, nrel_mpif, color = cell_type_sc)) +
  scale_color_manual(values = clrs$cell_type) +
  common_layers + 
  theme(axis.title.x = element_blank(),
        strip.text.x = element_blank())

cor_plot_list$p3 <- ggplot(filter(markers_pos_scrna_mpif_state, 
                                  cell_state %in% c("CD8+TOX+PD1+\nCD8+"))) +
  geom_point(aes(nrel, nrel_mpif, color = cell_type_sc)) +
  scale_color_manual(values = clrs$cell_type) +
  common_layers + 
  theme(axis.title.y = element_blank(),
        strip.text.x = element_blank())

cor_plot_grid <- plot_grid(plotlist = cor_plot_list, ncol = 1, align = "hv")
## cell type grid incl. mutational signatures: full-canvas grid with the
## three legends overlaid at hand-tuned positions
comp_grid_full <- ggdraw() +
  draw_plot(plot = gcomp1, x = 0, y = 0, width = 1, height = 1) +
  draw_grob(grob = vec_legend_shape,
            x = 0.54, y = 0.125, hjust = 0, vjust = 0.5) +
  draw_grob(grob = vec_legend_color,
            x = 0.68, y = 0.115, hjust = 0, vjust = 0.5) +
  draw_grob(grob = mpif_cell_type_legend,
            x = 0.84, y = 0.105, hjust = 0, vjust = 0.5)
  # draw_plot(cor_plot_grid, x = 0.68, y = 0.21, width = 0.35, height = 0.75) +
  # draw_label("Tumor", x = 0.815, y = 0.98) +
  # draw_label("Stroma", x = 0.91, y = 0.98)

comp_grid_full

## write the figure as pdf and png
ggsave(filename = "_fig/002_cohort/002_comp_grid_site_mutsig.pdf",
       plot = comp_grid_full, width = 12, height = 10)

ggsave(filename = "_fig/002_cohort/002_comp_grid_site_mutsig.png",
       plot = comp_grid_full, width = 12, height = 10)
## cell state grid: composition grid on the left 69% of the canvas, legends
## below it, and the scRNA vs mpIF correlation rows on the right with
## manually placed "Tumor"/"Stroma" column labels
comp_grid_full <- ggdraw() +
  draw_plot(plot = gcomp2, x = 0, y = 0, width = 0.69, height = 1) +
  draw_grob(grob = vec_legend_shape,
            x = 0.37, y = 0.13, hjust = 0, vjust = 0.5) +
  draw_grob(grob = vec_legend_color,
            x = 0.47, y = 0.115, hjust = 0, vjust = 0.5) +
  draw_grob(grob = mpif_state_legend,
            x = 0.60, y = 0.105, hjust = 0, vjust = 0.5) +
  draw_plot(plot = cor_plot_grid,
            x = 0.68, y = 0.21, width = 0.35, height = 0.75) +
  draw_label(label = "Tumor", x = 0.815, y = 0.98) +
  draw_label(label = "Stroma", x = 0.91, y = 0.98)

comp_grid_full

## write the figure as pdf and png
ggsave(filename = "_fig/002_cohort/002_comp_grid_site.pdf",
       plot = comp_grid_full, width = 16, height = 8)

ggsave(filename = "_fig/002_cohort/002_comp_grid_site.png",
       plot = comp_grid_full, width = 16, height = 8)
## cell type grid without mutational signatures: full-canvas grid with the
## three legends overlaid at hand-tuned positions
comp_grid_full <- ggdraw() +
  draw_plot(plot = gcomp3, x = 0, y = 0, width = 1, height = 1) +
  draw_grob(grob = vec_legend_shape,
            x = 0.54, y = 0.125, hjust = 0, vjust = 0.5) +
  draw_grob(grob = vec_legend_color,
            x = 0.68, y = 0.115, hjust = 0, vjust = 0.5) +
  draw_grob(grob = mpif_cell_type_legend,
            x = 0.84, y = 0.105, hjust = 0, vjust = 0.5)
  # draw_plot(cor_plot_grid, x = 0.68, y = 0.21, width = 0.35, height = 0.75) +
  # draw_label("Tumor", x = 0.815, y = 0.98) +
  # draw_label("Stroma", x = 0.91, y = 0.98)

comp_grid_full

## write the figure as pdf and png
ggsave(filename = "_fig/002_cohort/002_comp_grid_site_cell_type.pdf",
       plot = comp_grid_full, width = 12, height = 8)

ggsave(filename = "_fig/002_cohort/002_comp_grid_site_cell_type.png",
       plot = comp_grid_full, width = 12, height = 8)

5 Supplementary figure

## scrna per patient comps
## faceted by patient; box and vector panels are switched off, the site and
## mutational signature tile rows are switched on
plist1 <- default_comp_grid_list(
  filter(comp_tbl_sample, sort_short_x == "CD45-"), 
  cell_type, "Ov cancer cell", cell_type, facet = patient_id_short, 
  site_box = FALSE, vec_plot = FALSE, mutsig_box = FALSE,
  site_tiles = TRUE, mutsig_tiles = TRUE)

plist2 <- default_comp_grid_list(
  filter(comp_tbl_sample, sort_short_x == "CD45+"), 
  cell_type, "T cell", cell_type, facet = patient_id_short, 
  site_box = FALSE, vec_plot = FALSE, mutsig_box = FALSE,
  site_tiles = TRUE, mutsig_tiles = TRUE)

## sample orderings for the TCGA subtype bars. Each sort fraction is ranked
## by the same reference cell type as the per-patient composition lists
## (CD45-: "Ov cancer cell", CD45+: "T cell") so that the bars line up with
## the panels they are stacked under; the two references were swapped before.
## NOTE(review): rank_by() is defined elsewhere -- confirm that sample_id_lvl
## depends on the reference cell type as assumed here.
cd45n_lvls <- comp_tbl_sample %>% 
  filter(sort_short_x == "CD45-") %>% 
  rank_by(cell_type, "Ov cancer cell", cell_type) %>% 
  distinct(sample_id, sample_id_lvl)

cd45p_lvls <- comp_tbl_sample %>% 
  filter(sort_short_x == "CD45+") %>% 
  rank_by(cell_type, "T cell", cell_type) %>% 
  distinct(sample_id, sample_id_lvl)

## per-patient bar plots of the consensusOV subtype composition, one per
## sort fraction; sample order comes from the joined *_lvls tables and the
## legends are removed
pbar_txga1 <- comp_tbl_consOV %>% 
  mutate(alpha_highlight = FALSE) %>% 
  filter(sort_short_x == "CD45-") %>% 
  left_join(cd45n_lvls, by = "sample_id") %>% 
  plot_comp_bar(sample_id_lvl, nrel, consensusOV, facet = patient_id_short) +
  remove_guides

pbar_txga2 <- comp_tbl_consOV %>% 
  mutate(alpha_highlight = FALSE) %>% 
  filter(sort_short_x == "CD45+") %>% 
  left_join(cd45p_lvls, by = "sample_id") %>% 
  plot_comp_bar(sample_id_lvl, nrel, consensusOV, facet = patient_id_short) +
  remove_guides


## per-patient scRNA stacks: two bar panels from the composition list, the
## TCGA subtype bars and two thin tile rows, aligned on the x axis

## CD45- fraction
pgrid_supplement1 <- plot_grid(
  plist1[["pbar1"]] + theme(strip.text.x = element_text(angle = 90)),
  plist1[["pbar2"]] + labs(y = "% cells\n(cell type)"),
  pbar_txga1 + labs(y = "% cells\n(TCGA)"),
  plist1[["ptiles2"]],
  plist1[["ptiles1"]],
  ncol = 1, axis = "x", align = "v",
  rel_heights = c(0.3, 0.3, 0.3, 0.05, 0.05)
)

## CD45+ fraction
pgrid_supplement2 <- plot_grid(
  plist2[["pbar1"]] + theme(strip.text.x = element_text(angle = 90)),
  plist2[["pbar2"]] + labs(y = "% cells\n(cell type)"),
  pbar_txga2 + labs(y = "% cells\n(TCGA)"),
  plist2[["ptiles2"]],
  plist2[["ptiles1"]],
  ncol = 1, axis = "x", align = "v",
  rel_heights = c(0.3, 0.3, 0.3, 0.05, 0.05)
)

## mpif per patient comps
## sample ordering across both compartments, ranked by CD8+
mpif_lvls <- mpif_cell_type_n_slide_compartment %>% 
  rank_by(cell_type, "CD8+", cell_type) %>% 
  distinct(sample_id, sample_id_lvl)

## composition panels over all compartments (the former condition-less
## filter() call was a no-op and has been dropped)
plist_mpif <- default_comp_grid_list(
  mpif_cell_type_n_slide_compartment, 
  cell_type, "CD8+", cell_type, facet = patient_id_short, nmax = 750000,
  site_box = FALSE, vec_plot = FALSE, mutsig_box = FALSE,
  site_tiles = TRUE, mutsig_tiles = TRUE)

## per-patient mpIF cell type bars for each compartment; cell types are
## ordered with "CD8+" first, followed by the names of clrs$cell_type, and
## the fill legend is hidden ("none" replaces the deprecated logical `F`)
pbar_mpif_tumor <- mpif_cell_type_n_slide_compartment %>% 
  filter(sort_short_x == "Tumor") %>% 
  mutate(alpha_highlight = FALSE) %>% 
  left_join(mpif_lvls, by = "sample_id") %>% 
  mutate(cell_type = ordered(cell_type, levels = unique(c("CD8+", names(clrs$cell_type))))) %>% 
  plot_comp_bar(sample_id_lvl, nrel, cell_type, nmax = 750000, 
                facet = patient_id_short) +
  guides(fill = "none") +
  labs(y = "% cells\n(Tumor)")

pbar_mpif_stroma <- mpif_cell_type_n_slide_compartment %>% 
  filter(sort_short_x == "Stroma") %>% 
  mutate(alpha_highlight = FALSE) %>% 
  left_join(mpif_lvls, by = "sample_id") %>% 
  mutate(cell_type = ordered(cell_type, levels = unique(c("CD8+", names(clrs$cell_type))))) %>% 
  plot_comp_bar(sample_id_lvl, nrel, cell_type, nmax = 750000, 
                facet = patient_id_short) +
  guides(fill = "none") +
  labs(y = "% cells\n(Stroma)")


## per-patient mpIF stack: cell count bars, tumor and stroma composition
## bars and two thin tile rows, aligned on the x axis
pgrid_supplement3 <- plot_grid(
  plist_mpif[["pbar1"]] +
    theme(strip.text.x = element_text(angle = 90)) +
    labs(y = "# cells\n"),
  pbar_mpif_tumor,
  pbar_mpif_stroma,
  plist_mpif[["ptiles2"]],
  plist_mpif[["ptiles1"]],
  ncol = 1, axis = "x", align = "v",
  rel_heights = c(0.3, 0.3, 0.3, 0.05, 0.05)
)

## all composition plots combined: the three per-patient stacks on top of
## each other, separated by blank spacer rows; the mpIF stack only takes 55%
## of the row width
pgrid_supplement_full <- plot_grid(
  plot_grid(pgrid_supplement1, ggdraw(),
            rel_widths = c(1, 0), nrow = 1),
  ggdraw(),
  plot_grid(pgrid_supplement2, ggdraw(),
            rel_widths = c(1, 0), nrow = 1),
  ggdraw(),
  plot_grid(pgrid_supplement3, ggdraw(),
            rel_widths = c(0.55, 0.45), nrow = 1),
  rel_heights = c(0.3, 0.05, 0.3, 0.05, 0.3),
  ncol = 1
)

pgrid_supplement_full

## write the figure as pdf and png
ggsave(filename = "_fig/002_cohort/002_comp_grid_per_patient.pdf",
       plot = pgrid_supplement_full, width = 20, height = 12)
ggsave(filename = "_fig/002_cohort/002_comp_grid_per_patient.png",
       plot = pgrid_supplement_full, width = 20, height = 12)